#!/usr/bin/env python
# -*- coding: utf-8; py-indent-offset:4 -*-
###############################################################################
#
# Copyright (C) 2015-2020 Daniel Rodriguez
#
# This program is free software: you can redistribute it and/or modify
# it under the terms of the GNU General Public License as published by
# the Free Software Foundation, either version 3 of the License, or
# (at your option) any later version.
#
# This program is distributed in the hope that it will be useful,
# but WITHOUT ANY WARRANTY; without even the implied warranty of
# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the
# GNU General Public License for more details.
#
# You should have received a copy of the GNU General Public License
# along with this program.  If not, see <http://www.gnu.org/licenses/>.
#
###############################################################################
from __future__ import (absolute_import, division, print_function,
                        unicode_literals)


import collections
import math
import os
import backtrader as bt
from backtrader.analyzers.drawdown import NO_TIMEFRAME_DT_KEY
from . import TradeAnalyzer, TimeDrawDown, Calmar, SQN, SharpeRatio, AnnualReturn, MonthlyReturn, NetAssetValue, RangeReturn, BenchmarkReturn
from backtrader.utils.datehelper import *

"""
适合单品种指标分析
Analyzers intended for single-instrument (single-symbol) indicator analysis.
"""


class TradeIndicator:
    """Accumulate trade statistics over a sequence of closed trades.

    Call :meth:`update` once per closed trade, in order, then :meth:`analysis`
    to get every public statistic (attributes not starting with ``_``) as an
    ``OrderedDict`` in declaration order.

    A trade counts as a win only when its pnl after commission is > 0; every
    other trade (including pnl == 0) counts as a loss.
    """

    def __init__(self):
        self.win_rate = 0          # win rate: win_count / trade_count
        self.trade_count = 0       # number of closed trades
        self.win_count = 0         # winning trades (pnl after commission > 0)
        self.loss_count = 0        # losing trades (pnl after commission <= 0)
        self.win_loss_ratio = 0    # payoff ratio: abs(avg_win / avg_loss)
        self.profit_factor = 0     # profit factor: abs(win_pnl / loss_pnl)
        self.profit_expect = 0     # expectancy: win_rate * win_loss_ratio - (1 - win_rate)
        self.trade_pnl = 0         # net pnl (after commission)
        self.win_pnl = 0           # gross profit (after commission)
        self.loss_pnl = 0          # gross loss (after commission)
        self.trade_comm = 0        # total commission
        self.avg_win = 0           # average win (after commission): win_pnl / win_count
        self.avg_loss = 0          # average loss (after commission): loss_pnl / loss_count
        self.best_win = 0          # largest single-trade profit
        self.worse_loss = 0        # largest single-trade loss
        self.best_ret = 0          # best single-trade return
        self.worse_ret = 0         # worst single-trade return
        self._total_held = 0       # total bars held over all trades (for avg_held)
        self.longest_held = 0      # longest holding period (bars)
        self.avg_held = 0          # average holding period (bars)
        self.win_longest = 0       # longest run of consecutive winning trades
        self.win_longest_pnl = 0   # largest cumulative pnl of a winning run
        self.win_longest_idx = 0   # idx of the first trade of the winning run with the largest pnl
        self.loss_longest = 0      # longest run of consecutive losing trades
        self.loss_longest_pnl = 0  # largest cumulative loss of a losing run
        self.loss_longest_idx = 0  # idx of the first trade of the losing run with the largest loss
        self._current_held = 0     # length of the run in progress
        self._current_pnl = 0      # cumulative pnl of the run in progress
        self._current_idx = 0      # idx of the first trade of the run in progress
        self._pre_pnl_status = 0   # status of the previous trade: win=1, loss=-1, none yet=0

    def _folded_streaks(self):
        """Return the run records as if the run in progress had just ended.

        Returns ``((win_longest, win_longest_pnl, win_longest_idx),
        (loss_longest, loss_longest_pnl, loss_longest_idx))`` without
        modifying ``self``.
        """
        win = (self.win_longest, self.win_longest_pnl, self.win_longest_idx)
        loss = (self.loss_longest, self.loss_longest_pnl, self.loss_longest_idx)
        if self._pre_pnl_status > 0:
            longest, best_pnl, best_idx = win
            if self._current_pnl > best_pnl:
                best_idx = self._current_idx
            win = (max(longest, self._current_held),
                   max(best_pnl, self._current_pnl),
                   best_idx)
        elif self._pre_pnl_status < 0:
            longest, worst_pnl, worst_idx = loss
            if self._current_pnl < worst_pnl:
                worst_idx = self._current_idx
            loss = (max(longest, self._current_held),
                    min(worst_pnl, self._current_pnl),
                    worst_idx)
        return win, loss

    def update(self, idx, pnlcomm, comm, closesize, closeprice, closevalue, barlen, closetime):
        """Record one closed trade and refresh every derived statistic.

        Args:
            idx: position of the trade in the caller's trade list; stored as
                the starting index of winning/losing runs.
            pnlcomm: trade pnl after commission.
            comm: commission paid for the trade.
            closesize, closeprice, closetime: accepted for interface
                compatibility; not used.
            closevalue: value of the closed position. Per the original
                author's note it is negative for spot positions, so the
                per-trade return is ``pnlcomm / -closevalue`` (0 when
                ``closevalue`` is 0).
            barlen: number of bars the trade was held.
        """
        self.trade_count += 1

        # Totals
        self.trade_pnl += pnlcomm
        self.trade_comm += comm

        # Track runs of consecutive winning / losing trades.
        pnl_status = 1 if pnlcomm > 0 else -1
        if pnl_status != self._pre_pnl_status:
            # The previous run (if any) just ended: fold it into the records
            # and start a new run at this trade (also for the very first one,
            # so its start idx is recorded correctly).
            win, loss = self._folded_streaks()
            self.win_longest, self.win_longest_pnl, self.win_longest_idx = win
            self.loss_longest, self.loss_longest_pnl, self.loss_longest_idx = loss
            self._current_pnl = pnlcomm
            self._current_held = 1
            self._current_idx = idx
        else:
            self._current_pnl += pnlcomm
            self._current_held += 1
        self._pre_pnl_status = pnl_status

        # Avoid ZeroDivisionError on a zero close value.
        _ret = pnlcomm / -closevalue if closevalue else 0
        if pnlcomm > 0:
            self.win_count += 1
            self.win_pnl += pnlcomm
            self.best_win = max(self.best_win, pnlcomm)
            self.best_ret = max(self.best_ret, _ret)
        else:
            self.loss_count += 1
            self.loss_pnl += pnlcomm
            self.worse_loss = min(self.worse_loss, pnlcomm)
            self.worse_ret = min(self.worse_ret, _ret)

        self._total_held += barlen
        self.longest_held = max(self.longest_held, barlen)
        self.avg_held = self._total_held / self.trade_count
        self.win_rate = self.win_count / self.trade_count
        self.avg_win = self.win_pnl / self.win_count if self.win_count > 0 else 0
        self.avg_loss = self.loss_pnl / self.loss_count if self.loss_count > 0 else 0
        self.win_loss_ratio = abs(self.avg_win / self.avg_loss) if self.avg_loss != 0 else 0
        self.profit_factor = abs(self.win_pnl / self.loss_pnl) if self.loss_pnl != 0 else 0
        if self.avg_loss != 0:
            # Also applies when there are no wins (ratio 0): expectancy is negative.
            self.profit_expect = self.win_rate * self.win_loss_ratio - (1 - self.win_rate)
        else:
            # No non-zero loss yet: the payoff ratio is undefined, keep the
            # historical sentinel value of 1.
            self.profit_expect = 1

    def analysis(self):
        """Return all public statistics as an ``OrderedDict``.

        The run in progress is folded into the ``win_longest*`` /
        ``loss_longest*`` entries so the final run is counted; the object
        itself is not modified, so this may be called repeatedly.
        """
        rs = collections.OrderedDict(
            (k, v) for k, v in self.__dict__.items() if not k.startswith("_"))
        win, loss = self._folded_streaks()
        rs["win_longest"], rs["win_longest_pnl"], rs["win_longest_idx"] = win
        rs["loss_longest"], rs["loss_longest_pnl"], rs["loss_longest_idx"] = loss
        return rs


class SimpleReport(bt.Analyzer):
    """One-shot summary report for a single-instrument backtest.

    During the run it does three things:

    * records closed trades (``notify_trade``);
    * records the first buy / last sell fill price per symbol
      (``notify_order``);
    * tracks the running close-price drawdown per symbol (``next``).

    In ``stop`` it fills ``self.rets`` with:

    * return, risk, trade and monthly statistics;
    * yearly returns and start/end dates;
    * the strategy parameters.

    When both ``csv`` and ``out`` are set, it also writes ``report.csv``,
    ``bymonth.csv``, ``byyear.csv`` and ``byrange.csv`` into
    ``<out>/<analyzer param values joined by "_">/``.

    NOTE(review): ``stop`` reads ``self.strategy.p.size`` and divides it by
    100 to scale buy&hold / benchmark figures. The strategy must therefore
    define a numeric ``size`` param (presumably a position size in percent;
    confirm).

    Params:
      - ``timeframe`` / ``compression``: resolution forwarded to the
        drawdown and Sharpe-ratio child analyzers.
      - ``csv``: write the CSV files (only when ``out`` is also set).
      - ``out``: root directory for the CSV files.
      - ``rounding``: number of decimals used in the CSV files.
    """
    params = (
        ('timeframe', bt.TimeFrame.Days),
        ('compression', 1),
        ('csv', False),
        ('out', None),
        ('rounding', 3)
    )

    def __init__(self):
        dtfcomp = dict(timeframe=self.p.timeframe,
                       compression=self.p.compression)
        # Portfolio value when the analyzer is created: reference for total_ret
        self._init_value = self.strategy.broker.get_value()

        # Max drawdown ratio / duration over the whole run
        self._mdd = TimeDrawDown(**dtfcomp)

        # Max drawdown broken down by year
        self._year_mdd = TimeDrawDown(timeframe=self.p.timeframe, compression=self.p.compression, dt_mode=bt.TimeFrame.Years)

        # Annualized Sharpe ratio
        self._sharp_ratio = SharpeRatio(**dtfcomp, annualize=True)

        # Returns: yearly, monthly, recent ranges and the benchmark
        self._yearly_rets = AnnualReturn()
        self._monthly_rets = MonthlyReturn()
        self._range_rets = RangeReturn()
        self._benchmark_rets = BenchmarkReturn()

        # Closed-trade history, one list per trade (layout: see notify_trade)
        self._trade_pnl = list()

        # System Quality Number
        self._sqn = SQN()

        # Per-symbol buy&hold bookkeeping
        self.first_prices, self.last_prices, self.max_prices, self.mdds = {}, {}, {}, {}

    def cal_buy_hold(self):
        """Return the equally weighted buy&hold ``(return, max drawdown)``.

        Only symbols with both a first buy price and a last sell price
        contribute. Each one adds ``1 / len(self.datas)`` of its own return
        and of its own max drawdown.
        """
        weight = 1 / len(self.datas)
        rets = 0
        mdds = 0
        for symbol, first_price in self.first_prices.items():
            if symbol in self.last_prices:
                rets += (self.last_prices[symbol] / first_price - 1) * weight
                mdds += self.mdds[symbol] * weight
        return rets, mdds

    def next(self):
        """Update the running max close and max drawdown of each tracked symbol."""
        for d in self.datas:
            symbol = d._name
            # Tracking only starts after the first completed buy (notify_order)
            if symbol not in self.max_prices:
                continue
            close = d.close[0]
            if close > self.max_prices[symbol]:
                self.max_prices[symbol] = close
            else:
                drawdown = abs(close / self.max_prices[symbol] - 1)
                self.mdds[symbol] = max(drawdown, self.mdds[symbol])

    def notify_trade(self, trade):
        """Store the figures of every closed trade for later aggregation.

        Stored layout: ``[pnlcomm, commission, closesize, closeprice,
        closevalue, barlen, close_datetime]``.
        """
        if trade.isclosed:
            self._trade_pnl.append([trade.pnlcomm, trade.commission, trade.closesize,
                                    trade.closeprice, trade.closevalue, trade.barlen,
                                    trade.close_datetime()])

    def notify_order(self, order):
        """Record the first buy fill and the latest sell fill price per symbol."""
        symbol = order.data._name
        if order.status == order.Completed:
            if order.isbuy() and symbol not in self.first_prices:
                self.first_prices[symbol] = order.executed.price
                self.max_prices[symbol] = order.executed.price
                self.mdds[symbol] = 0
            if order.issell():
                self.last_prices[symbol] = order.executed.price

    def curr_dt_key(self, dt, dt_mode):
        """Return the grouping key of ``dt``.

        The key is ``year * 100 + month`` for ``Months`` and ``year`` for
        ``Years``. Any other mode returns -1, which means a single bucket.
        """
        if dt_mode == bt.TimeFrame.Months:
            dt_key = dt.year * 100 + dt.month
        elif dt_mode == bt.TimeFrame.Years:
            dt_key = dt.year
        else:
            dt_key = -1
        return dt_key

    def cal_trade(self, dt_mode=bt.TimeFrame.Months, filter_dt=None, key="all") -> collections.OrderedDict:
        """Aggregate the recorded closed trades into per-period statistics.

        Trades are processed in recording order. A new bucket starts whenever
        the period key changes, so each period is expected to be contiguous.

        Args:
            dt_mode: grouping mode passed to :meth:`curr_dt_key`.
            filter_dt: when truthy, trades closed at or before it are skipped.
            key: bucket name used when ``dt_mode`` is
                ``bt.TimeFrame.NoTimeFrame``.

        Returns:
            OrderedDict mapping each period key to a ``TradeIndicator``
            analysis. With ``NoTimeFrame`` the ``key`` entry is always
            present (all zeros when no trade matched). Otherwise, periods
            without trades are absent.
        """
        trade_res = collections.OrderedDict()

        pre_dt_key = None
        trade_ind = TradeIndicator()

        for idx, item in enumerate(self._trade_pnl):
            _pnlcomm, _comm, closevalue, barlen, dt = item[0], item[1], item[4], item[5], item[-1]
            if filter_dt and dt <= filter_dt:
                continue
            dt_key = self.curr_dt_key(dt, dt_mode)
            if pre_dt_key is not None and dt_key != pre_dt_key and dt_key != -1:
                # Period changed: close the previous bucket and start a new one
                trade_res[pre_dt_key] = trade_ind.analysis()
                trade_ind = TradeIndicator()
            pre_dt_key = dt_key
            trade_ind.update(idx, _pnlcomm, _comm, None, None, closevalue, barlen, dt)

        if dt_mode == bt.TimeFrame.NoTimeFrame:
            pre_dt_key = key
        if pre_dt_key is not None:
            trade_res[pre_dt_key] = trade_ind.analysis()
        return trade_res

    def stop(self):
        """Fill ``self.rets`` with the report and optionally write the CSV files."""
        # Initial / final portfolio value
        _init_value = self._init_value
        _last_value = self.strategy.broker.get_value()
        _last_cash = self.strategy.broker.get_cash()

        bar_len = self.data.datetime.buflen()
        start_time = self.data.datetime.datetime(-(bar_len - 1))
        end_time = self.data.datetime.datetime(0)
        days = (end_time - start_time).days

        # Buy&hold return and drawdown
        _bh_ret, _bh_mdd = self.cal_buy_hold()

        # Scaling factor for buy&hold / benchmark figures (see class NOTE)
        size_ratio = float(self.strategy.p.__dict__["size"]) / 100

        _ret = _last_value / _init_value - 1

        self.rets["total_ret"] = _ret
        self.rets["b&h_ret"] = _bh_ret * size_ratio
        # A run shorter than one calendar day cannot be annualized.
        self.rets["annual_ret"] = _ret / days * 365 if days > 0 else float('nan')

        # Benchmark returns for each period, scaled like buy&hold. A new dict
        # is built so the child analyzer's own results are not mutated.
        benchmark_rets = collections.OrderedDict(
            (k, v * size_ratio) for k, v in self._benchmark_rets.get_analysis().items())
        benchmark_rets["all"] = self.rets["b&h_ret"]

        total_mdd = self._mdd.get_analysis()[NO_TIMEFRAME_DT_KEY]
        _mdd = total_mdd["mdd"]

        self.rets["rets_risk_ratio"] = self.rets["annual_ret"] / _mdd if _mdd else float('nan')
        self.rets["sharp"] = self._sharp_ratio.get_analysis()["sharperatio"]
        self.rets["sqn"] = self._sqn.get_analysis()["sqn"]

        # Max drawdown and drawdown durations
        self.rets["mdd"] = _mdd
        self.rets["mddvalue"] = total_mdd["mddvalue"]
        self.rets["mddlen"] = total_mdd["mddlen"]          # duration within the max drawdown
        self.rets["mddlen2"] = total_mdd["mddlen2"]        # duration within the max drawdown
        self.rets["maxmddlen"] = total_mdd["maxmddlen"]    # longest drawdown duration
        self.rets["maxmddlen2"] = total_mdd["maxmddlen2"]  # longest drawdown duration
        self.rets["b&hmdd"] = _bh_mdd * size_ratio

        _month_returns = self._monthly_rets.get_analysis()
        _year_returns = self._yearly_rets.get_analysis()
        _range_returns = self._range_rets.get_analysis()

        # Statistics over all trades
        _all_trade = self.cal_trade(dt_mode=bt.TimeFrame.NoTimeFrame, key="all")
        for k, v in _all_trade["all"].items():
            self.rets[k] = v

        # Monthly return summary (nan when no monthly return is available)
        m_returns = list(_month_returns.values())
        if m_returns:
            self.rets["best_month_ret"] = max(m_returns)
            self.rets["worst_month_ret"] = min(m_returns)
            self.rets["avg_month_ret"] = sum(m_returns) / len(m_returns)
        else:
            self.rets["best_month_ret"] = float('nan')
            self.rets["worst_month_ret"] = float('nan')
            self.rets["avg_month_ret"] = float('nan')
        self.rets["win_month_count"] = sum(1 for x in m_returns if x > 0)
        self.rets["lost_month_count"] = sum(1 for x in m_returns if x <= 0)

        # Yearly returns
        for date, _returns in _year_returns.items():
            self.rets[date] = _returns

        self.rets["init_cash"] = _init_value
        self.rets["last_value"] = _last_value
        self.rets["last_cash"] = _last_cash

        self.rets["start_date"] = int(date2str(start_time, _format_str="%Y%m%d"))
        self.rets["ended_date"] = int(date2str(end_time, _format_str="%Y%m%d"))

        # Strategy parameters go into the report too
        pp = self.strategy.p.__dict__.copy()
        if "timeframe" in pp:
            del pp["timeframe"]
        self.rets.update(pp)

        if self.p.csv and self.p.out:
            self._write_csv(
                self._build_month_stats(_month_returns, benchmark_rets),
                self._build_year_stats(_year_returns, _ret, benchmark_rets, _all_trade),
                self._build_range_stats(_range_returns, benchmark_rets, end_time),
            )

    def _build_year_stats(self, year_returns, total_ret, benchmark_rets, all_trade):
        """Build per-year rows plus an ``"all"`` row (returns, benchmark, drawdown, trades)."""
        stats = collections.OrderedDict()

        for year, ret in year_returns.items():
            stats.setdefault(year, collections.OrderedDict())["total_ret"] = ret
        stats.setdefault("all", collections.OrderedDict())["total_ret"] = total_ret

        # Benchmark per row; nan when the benchmark has no entry for it
        for period, row in stats.items():
            row["b&h_ret"] = benchmark_rets.get(period, float('nan'))

        # Per-year drawdown
        mdd_analysis = self._year_mdd.get_analysis()
        if len(mdd_analysis) > 0:
            for year, year_mdd in mdd_analysis.items():
                stats.setdefault(year, collections.OrderedDict()).update(year_mdd)
            stats["all"].update(self._mdd.get_analysis()[NO_TIMEFRAME_DT_KEY])

        # Per-year trade statistics
        year_trades = self.cal_trade(dt_mode=bt.TimeFrame.Years)
        if len(year_trades) > 0:
            for year, trade_stat in year_trades.items():
                stats.setdefault(year, collections.OrderedDict()).update(trade_stat)
            stats["all"].update(all_trade["all"])
        return stats

    def _build_month_stats(self, month_returns, benchmark_rets):
        """Build per-month rows (returns, benchmark, trade stats zero-filled by default)."""
        stats = collections.OrderedDict()
        for month, ret in month_returns.items():
            row = stats.setdefault(month, collections.OrderedDict())
            row["total_ret"] = ret
            row.update(TradeIndicator().analysis())  # zero defaults for trade columns

        for period, row in stats.items():
            row["b&h_ret"] = benchmark_rets.get(period, float('nan'))

        for month, trade_stat in self.cal_trade(dt_mode=bt.TimeFrame.Months).items():
            stats.setdefault(month, collections.OrderedDict()).update(trade_stat)
        return stats

    def _build_range_stats(self, range_returns, benchmark_rets, end_time):
        """Build rows for the recent ranges: last 1/3/6 months and year-to-date."""
        stats = collections.OrderedDict()
        for range_key, ret in range_returns.items():
            stats.setdefault(range_key, collections.OrderedDict())["total_ret"] = ret

        for period, row in stats.items():
            row["b&h_ret"] = benchmark_rets.get(period, float('nan'))

        for period in ["1month", "3month", "6month", "ytd"]:
            range_date, range_key = recent_date(end_time, period)
            range_trades = self.cal_trade(dt_mode=bt.TimeFrame.NoTimeFrame, filter_dt=range_date, key=range_key)
            for k, trade_stat in range_trades.items():
                stats.setdefault(k, collections.OrderedDict()).update(trade_stat)
        return stats

    def _write_csv(self, month_stats, year_stats, range_stats):
        """Write report.csv, bymonth.csv, byyear.csv and byrange.csv."""
        import pandas as pd

        params = self.get_params()
        report_dir = "_".join(str(x) for x in params.values())
        out = os.path.join(self.p.out, report_dir)
        os.makedirs(out, exist_ok=True)

        # Overall report: one (symbol, ind, value) row per self.rets entry
        df = pd.DataFrame.from_dict(self.rets, orient='index', columns=["value"])
        df.index.name = "ind"
        df["symbol"] = self.data._name
        df = df.reset_index()
        df = df[["symbol", "ind", "value"]]
        df = df.round(decimals=self.p.rounding)
        df.to_csv(os.path.join(out, "report.csv"), index=False)

        # Monthly table: one row per month
        cols = ["total_ret", "b&h_ret", "win_rate", "trade_count", "win_count", "loss_count",
                "trade_pnl", "win_pnl", "loss_pnl", "trade_comm", "avg_win", "avg_loss", "profit_factor", "profit_expect"]
        df2 = pd.DataFrame.from_dict(month_stats).T
        # reindex (not []) so a missing column becomes NaN (-> 0) instead of KeyError
        df2 = df2.reindex(columns=cols)
        df2.index.name = "month"
        df2 = df2.reset_index()
        df2 = df2.fillna(0)
        df2 = df2.round(decimals=self.p.rounding + 1)
        for col in df2.columns:
            if col == 'month':
                continue
            as_text = df2[col].astype(str)
            pad = as_text.str.len().max()
            # At least 5 characters; also covers an empty table (max is NaN)
            pad = 5 if pd.isna(pad) or pad < 5 else int(pad)
            df2[col] = as_text.str.ljust(pad, "0")
        df2.to_csv(os.path.join(out, "bymonth.csv"), index=False)

        # Yearly and recent-range tables share the same layout
        self._write_stat_table(year_stats, os.path.join(out, "byyear.csv"))
        self._write_stat_table(range_stats, os.path.join(out, "byrange.csv"))

    def _write_stat_table(self, stats, fpath):
        """Write a (metric x period) table with leading ``symbol``/``ind`` columns."""
        import pandas as pd

        df = pd.DataFrame.from_dict(stats)
        # Rows labelled "dtype" are excluded (kept from the original logic).
        df = df.loc[[e for e in df.index if e != "dtype"]]
        df = df.astype(float, copy=True, errors='raise')
        period_cols = list(df.columns)
        df.index.name = "ind"
        df["symbol"] = self.data._name
        df = df.reset_index()
        df = df[["symbol", "ind"] + period_cols]
        df = df.fillna(0)
        df = df.round(decimals=self.p.rounding)

        # Right-align every column except symbol to its widest value
        for col in df.columns:
            if col == 'symbol':
                continue
            as_text = df[col].astype(str)
            df[col] = as_text.str.rjust(as_text.str.len().max(), " ")

        df.to_csv(fpath, index=False)

